import java.util.*;

/**
 * Collects every value stored in two binary trees into one ascending list.
 *
 * <p>{@code TreeNode} is declared elsewhere; this class relies only on its
 * {@code left}, {@code right} and {@code val} members.
 */
public class Leetcode1305 {

    /**
     * Returns all values from both trees in ascending order.
     *
     * @param root1 root of the first tree; {@code null} is treated as an empty tree
     * @param root2 root of the second tree; {@code null} is treated as an empty tree
     * @return a new list holding every value of both trees, sorted ascending
     *         (duplicate values are kept)
     */
    public List<Integer> getAllElements(TreeNode root1, TreeNode root2) {
        List<Integer> values = new ArrayList<>();
        appendInOrder(values, root1);
        appendInOrder(values, root2);

        // The explicit sort keeps the result ordered even when an input tree
        // does not satisfy the binary-search-tree ordering.
        Collections.sort(values);
        return values;
    }

    /**
     * Appends the values of {@code root}'s tree to {@code out} in in-order
     * (left subtree, node, right subtree) sequence.
     *
     * <p>Uses an explicit stack instead of recursion so that a very deep,
     * list-shaped tree cannot exhaust the thread's call stack.
     *
     * @param out  list that receives the values
     * @param root root of the tree to walk; may be {@code null}
     */
    private void appendInOrder(List<Integer> out, TreeNode root) {
        Deque<TreeNode> pending = new ArrayDeque<>();
        TreeNode current = root;
        while (current != null || !pending.isEmpty()) {
            // Descend as far left as possible, remembering the path taken.
            while (current != null) {
                pending.push(current);
                current = current.left;
            }
            // The most recently stacked node has no unvisited left subtree.
            TreeNode node = pending.pop();
            out.add(node.val);
            current = node.right;
        }
    }
}
